Source code for nlp_architect.models.gnmt.utils.iterator_utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
#   import paths
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: skip-file
"""For loading data into NMT models."""
from __future__ import print_function

import collections

import tensorflow as tf

from nlp_architect.models.gnmt.utils import vocab_utils


__all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"]


# NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty.
[docs]class BatchedInput( collections.namedtuple("BatchedInput", ("initializer", "source", "target_input", "target_output", "source_sequence_length", "target_sequence_length"))): pass
[docs]def get_infer_iterator(src_dataset, src_vocab_table, batch_size, eos, src_max_len=None, use_char_encode=False): if use_char_encode: src_eos_id = vocab_utils.EOS_CHAR_ID else: src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values) if src_max_len: src_dataset = src_dataset.map(lambda src: src[:src_max_len]) if use_char_encode: # Convert the word strings to character ids src_dataset = src_dataset.map( lambda src: tf.reshape(vocab_utils.tokens_to_bytes(src), [-1])) else: # Convert the word strings to ids src_dataset = src_dataset.map( lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32)) # Add in the word counts. if use_char_encode: src_dataset = src_dataset.map( lambda src: (src, tf.to_int32( tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN))) else: src_dataset = src_dataset.map(lambda src: (src, tf.size(src))) def batching_func(x): return x.padded_batch( batch_size, # The entry is the source line rows; # this has unknown-length vectors. The last entry is # the source row size; this is a scalar. padded_shapes=( tf.TensorShape([None]), # src tf.TensorShape([])), # src_len # Pad the source sequences with eos tokens. # (Though notice we don't generally need to do this since # later on we will be masking out calculations past the true sequence. padding_values=( src_eos_id, # src 0)) # src_len -- unused batched_dataset = batching_func(src_dataset) batched_iter = batched_dataset.make_initializable_iterator() (src_ids, src_seq_len) = batched_iter.get_next() return BatchedInput( initializer=batched_iter.initializer, source=src_ids, target_input=None, target_output=None, source_sequence_length=src_seq_len, target_sequence_length=None)
[docs]def get_iterator(src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table, batch_size, sos, eos, random_seed, num_buckets, src_max_len=None, tgt_max_len=None, num_parallel_calls=4, output_buffer_size=None, skip_count=None, num_shards=1, shard_index=0, reshuffle_each_iteration=True, use_char_encode=False): if not output_buffer_size: output_buffer_size = batch_size * 1000 if use_char_encode: src_eos_id = vocab_utils.EOS_CHAR_ID else: src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32) tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32) tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32) src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset)) src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index) if skip_count is not None: src_tgt_dataset = src_tgt_dataset.skip(skip_count) src_tgt_dataset = src_tgt_dataset.shuffle( output_buffer_size, random_seed, reshuffle_each_iteration) src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: ( tf.string_split([src]).values, tf.string_split([tgt]).values), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Filter zero length input sequences. src_tgt_dataset = src_tgt_dataset.filter( lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) if src_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src[:src_max_len], tgt), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) if tgt_max_len: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tgt[:tgt_max_len]), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Convert the word strings to ids. Word strings that are not in the # vocab get the lookup table's default_value integer. if use_char_encode: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]), tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), num_parallel_calls=num_parallel_calls) else: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32), tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), num_parallel_calls=num_parallel_calls) src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>. src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt: (src, tf.concat(([tgt_sos_id], tgt), 0), tf.concat((tgt, [tgt_eos_id]), 0)), num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size) # Add in sequence lengths. if use_char_encode: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt_in, tgt_out: ( src, tgt_in, tgt_out, tf.to_int32(tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN), tf.size(tgt_in)), num_parallel_calls=num_parallel_calls) else: src_tgt_dataset = src_tgt_dataset.map( lambda src, tgt_in, tgt_out: ( src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), num_parallel_calls=num_parallel_calls) src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size) # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...) def batching_func(x): return x.padded_batch( batch_size, # The first three entries are the source and target line rows; # these have unknown-length vectors. The last two entries are # the source and target row sizes; these are scalars. padded_shapes=( tf.TensorShape([None]), # src tf.TensorShape([None]), # tgt_input tf.TensorShape([None]), # tgt_output tf.TensorShape([]), # src_len tf.TensorShape([])), # tgt_len # Pad the source and target sequences with eos tokens. # (Though notice we don't generally need to do this since # later on we will be masking out calculations past the true sequence. padding_values=( src_eos_id, # src tgt_eos_id, # tgt_input tgt_eos_id, # tgt_output 0, # src_len -- unused 0)) # tgt_len -- unused if num_buckets > 1: def key_func(unused_1, unused_2, unused_3, src_len, tgt_len): # Calculate bucket_width by maximum source sequence length. # Pairs with length [0, bucket_width) go to bucket 0, length # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length # over ((num_bucket-1) * bucket_width) words all go into the last bucket. if src_max_len: bucket_width = (src_max_len + num_buckets - 1) // num_buckets else: bucket_width = 10 # Bucket sentence pairs by the length of their source sentence and target # sentence. bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) return tf.to_int64(tf.minimum(num_buckets, bucket_id)) def reduce_func(unused_key, windowed_data): return batching_func(windowed_data) batched_dataset = src_tgt_dataset.apply( tf.contrib.data.group_by_window( key_func=key_func, reduce_func=reduce_func, window_size=batch_size)) else: batched_dataset = batching_func(src_tgt_dataset) batched_iter = batched_dataset.make_initializable_iterator() (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = (batched_iter.get_next()) return BatchedInput( initializer=batched_iter.initializer, source=src_ids, target_input=tgt_input_ids, target_output=tgt_output_ids, source_sequence_length=src_seq_len, target_sequence_length=tgt_seq_len)